import pennylane as qml
from pennylane.ops import CNOT, RX, RY, RZ, CZ
from pennylane import numpy as np
import numpy.linalg as la
import math
from math import pi
from datetime import datetime
import os
import random


def f_n(weights, ansatz=None):
    """Return the scalar objective: the sum of ``ansatz(weights)``.

    The optimizers in this module differentiate this function with
    ``qml.grad`` and pass the circuit in through the ``ansatz`` keyword.

    Args:
        weights: Parameter vector handed straight to ``ansatz``.
        ansatz: Callable that maps ``weights`` to the circuit output(s). Its
            output is presumably expectation values (TODO confirm against the
            callers).

    Returns:
        ``np.sum`` of everything the ansatz returns.
    """
    outputs = ansatz(weights)
    return np.sum(outputs)


def gd_optimizer_s2fixed(ansatz, weights, s2mask, noise_gamma, lr, iteration, n_check):
    """Run noisy gradient descent on ``f_n`` with the ``s2mask`` parameters frozen.

    Each iteration:
    1. Evaluate the loss ``f_n(weights)`` and its gradient.
    2. Record the norm of the full (unmasked) gradient.
    3. Draw Gaussian noise with standard deviation ``noise_gamma``.
    4. Zero BOTH the gradient and the noise at the ``s2mask`` positions.
    5. Take the step ``weights -= lr * (grad + noise)``.

    Zeroing the noise as well guarantees the masked parameters really stay at
    their initial values.

    Args:
        ansatz: Circuit callable, forwarded to ``f_n``.
        weights: Initial parameter vector. The loop sets ``.requires_grad`` on
            it, so it is presumably a PennyLane numpy tensor.
        s2mask: Index array or boolean mask selecting the frozen parameters
            (anything valid for numpy fancy-index assignment).
        noise_gamma: Standard deviation of the injected Gaussian noise.
        lr: Step size.
        iteration: Number of steps.
        n_check: Print progress every ``n_check`` iterations. Use 0 to disable.

    Returns:
        ``(loss, grad_norm, weights)``. ``loss[i]`` and ``grad_norm[i]`` are
        measured BEFORE the i-th update. ``weights`` is the vector after the
        final update.
    """
    grad_func = qml.grad(f_n)
    n_params = len(weights)
    grad_norm = np.zeros(iteration)
    loss = np.zeros(iteration)

    t1 = datetime.now()
    for it in range(iteration):
        weights.requires_grad = True

        loss[it] = f_n(weights, ansatz=ansatz)
        grad_now = grad_func(weights, ansatz=ansatz)
        # Norm of the full gradient, taken before the frozen entries are masked.
        grad_norm[it] = la.norm(grad_now)

        noise = np.random.normal(0, noise_gamma, n_params)
        # Mask the gradient AND the noise. Masking only the gradient (the old
        # behavior) let the noise random-walk the "frozen" parameters.
        grad_now[s2mask] = 0.0
        noise[s2mask] = 0.0
        weights = weights - lr * (grad_now + noise)

        # n_check == 0 disables logging instead of raising ZeroDivisionError.
        if n_check and it % n_check == 0:
            t2 = datetime.now()
            # NOTE(review): np.sum(s2mask) is the frozen count only if s2mask is
            # a boolean mask; for an index array it is the sum of the indices.
            print(f"[GD-Fixed] iter={it}, loss={loss[it]:.6f}, grad={grad_norm[it]:.6f}, "
                  f"freeze={np.sum(s2mask)}, time={(t2 - t1).seconds}s")
            t1 = t2

    return loss, grad_norm, weights

def adam_optimizer_s2fixed(ansatz, weights, s2mask, noise_gamma, lr, iteration, n_check):
    """Run noisy Adam on ``f_n`` with the ``s2mask`` parameters frozen.

    Each iteration computes the gradient and records its full (unmasked) norm.
    It then adds Gaussian noise (std ``noise_gamma``) and zeroes the
    ``s2mask`` entries of the result. Finally it applies a bias-corrected Adam
    update with ``beta_1=0.9``, ``beta_2=0.99``, ``eps=1e-8``.

    Because the masked entries of the effective gradient are always exactly
    zero, their first and second moments stay zero. Their update is therefore
    ``lr * 0 / (0 + eps) == 0``, so frozen parameters never move.

    Args:
        ansatz: Circuit callable, forwarded to ``f_n``.
        weights: Initial parameter vector. The loop sets ``.requires_grad`` on
            it, so it is presumably a PennyLane numpy tensor.
        s2mask: Index array or boolean mask selecting the frozen parameters.
        noise_gamma: Standard deviation of the injected Gaussian noise.
        lr: Adam step size.
        iteration: Number of steps.
        n_check: Print progress every ``n_check`` iterations. Use 0 to disable.

    Returns:
        ``(loss, grad_norm, weights)``. The per-iteration loss and gradient
        norm are measured before each update. ``weights`` is the final vector.
    """
    beta_1, beta_2, eps = 0.9, 0.99, 1e-8
    grad_func = qml.grad(f_n)
    n = len(weights)
    m = np.zeros(n); v = np.zeros(n)
    grad_norm = np.zeros(iteration); loss = np.zeros(iteration)

    t1 = datetime.now()
    for it in range(iteration):
        weights.requires_grad = True
        loss[it] = f_n(weights, ansatz=ansatz)
        g = grad_func(weights, ansatz=ansatz)
        # Norm of the full gradient, taken before masking and noise.
        grad_norm[it] = la.norm(g)

        noise = np.random.normal(0, noise_gamma, n)
        g = g + noise
        # Mask AFTER adding noise. Previously the noise reached the frozen
        # entries, and Adam's normalization turned it into ~lr-sized moves.
        g[s2mask] = 0.0

        m = beta_1 * m + (1 - beta_1) * g
        v = beta_2 * v + (1 - beta_2) * (g**2)
        # Bias correction for the zero-initialized moment estimates.
        m_hat = m / (1 - beta_1**(it + 1))
        v_hat = v / (1 - beta_2**(it + 1))
        weights = weights - lr * m_hat / (np.sqrt(v_hat) + eps)

        # n_check == 0 disables logging instead of raising ZeroDivisionError.
        if n_check and it % n_check == 0:
            t2 = datetime.now()
            print(f"[Adam-Fixed] iter={it}, loss={loss[it]:.6f}, grad={grad_norm[it]:.6f}, "
                  f"freeze={np.sum(s2mask)}, time={(t2 - t1).seconds}s")
            t1 = t2

    return loss, grad_norm, weights


def _run_and_save(folder, ansatz, init_vecs, s2mask, noise_gamma,
                  lr, iteration, n_check, opt_name, optimizer=None):
    """Optimize from every initial vector in ``init_vecs`` and save the results.

    For the run with index ``idx`` (row ``init_vecs[idx]``), these ``.npy``
    files are written into ``folder``:
    ``weights_init_{idx}``, ``loss_{idx}``, ``grad_norm_{idx}``,
    ``weights_final_{idx}`` and ``energy_final_{idx}``. The folder and its
    parents are created if needed.

    Args:
        folder: Output directory.
        ansatz: Circuit callable, forwarded to the optimizer.
        init_vecs: Iterable of initial parameter vectors, one per run.
        s2mask: Mask of frozen parameters, forwarded to the optimizer.
        noise_gamma: Noise standard deviation, forwarded to the optimizer.
        lr: Step size, forwarded to the optimizer.
        iteration: Number of steps, forwarded to the optimizer.
        n_check: Logging interval, forwarded to the optimizer.
        opt_name: Label used only in the progress message.
        optimizer: Callable with the signature of ``gd_optimizer_s2fixed``
            that returns ``(loss, grad_norm, final_weights)``. Defaults to
            ``gd_optimizer_s2fixed``, which is the previous hard-wired behavior.
    """
    if optimizer is None:
        optimizer = gd_optimizer_s2fixed
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(folder, exist_ok=True)
    for idx, init in enumerate(init_vecs):
        np.save(os.path.join(folder, f"weights_init_{idx}.npy"), init)
        loss, grad_norm, w_final = optimizer(
            ansatz, init, s2mask, noise_gamma, lr, iteration, n_check)
        np.save(os.path.join(folder, f"loss_{idx}.npy"), loss)
        np.save(os.path.join(folder, f"grad_norm_{idx}.npy"), grad_norm)
        np.save(os.path.join(folder, f"weights_final_{idx}.npy"), w_final)
        # NOTE: the optimizers in this module record loss[i] BEFORE the i-th
        # update, so this value belongs to the weights one step before w_final.
        np.save(os.path.join(folder, f"energy_final_{idx}.npy"), np.array([loss[-1]]))
        print(f"[{opt_name}] run={idx}, final loss={loss[-1]:.6f}")

def gd_gaussian(ansatz, circuit_name, qubits, n_params, gamma, gau_rate,
                noise_gamma, noise_rate, lr, iteration, n_time, n_check, s2mask_source):
    """Run frozen-mask GD from ``n_time`` Gaussian initializations and save results.

    Initial parameters are drawn i.i.d. from a normal distribution with mean 0
    and standard deviation ``gamma * gau_rate``. The injected gradient noise
    uses standard deviation ``noise_gamma * noise_rate``. Results go into a
    folder named from the circuit, qubit count, parameter count and the two
    rates.
    """
    out_dir = f"./{circuit_name}_{qubits}_{n_params}_gd_gaussian_{gau_rate}_{noise_rate}_frozen"
    # One (n_time, n_params) draw; each row is a separate run's start point.
    init_std = gamma * gau_rate
    starts = np.random.normal(0, init_std, size=(n_time, n_params))
    scaled_noise = noise_gamma * noise_rate
    _run_and_save(out_dir, ansatz, starts, s2mask_source, scaled_noise,
                  lr, iteration, n_check, "GD-Gaussian")

def gd_zero(ansatz, circuit_name, qubits, n_params, *_,
            noise_gamma, noise_rate, lr, iteration, n_time, n_check, s2mask_source):
    """Run frozen-mask GD ``n_time`` times from the all-zeros start and save results.

    Any extra positional arguments after ``n_params`` are accepted and ignored.
    This is presumably so the function can share a call shape with
    ``gd_gaussian``, which takes ``gamma`` and ``gau_rate`` there. The
    remaining arguments are keyword-only. The injected gradient noise uses
    standard deviation ``noise_gamma * noise_rate``.
    """
    out_dir = f"./{circuit_name}_{qubits}_{n_params}_gd_zero_{noise_rate}_frozen"
    zero_starts = np.zeros((n_time, n_params))
    scaled_noise = noise_gamma * noise_rate
    _run_and_save(out_dir, ansatz, zero_starts, s2mask_source, scaled_noise,
                  lr, iteration, n_check, "GD-Zero")

def gd_uniform(ansatz, circuit_name, qubits, n_params, *_,
               noise_gamma, noise_rate, lr, iteration, n_time, n_check, s2mask_source):
    """Run frozen-mask GD from ``n_time`` uniform initializations and save results.

    Initial parameters are drawn i.i.d. from ``Uniform[0, 2*pi)``. Extra
    positional arguments after ``n_params`` are ignored; the rest are
    keyword-only. The injected gradient noise uses standard deviation
    ``noise_gamma * noise_rate``.
    """
    out_dir = f"./{circuit_name}_{qubits}_{n_params}_gd_uniform_{noise_rate}_frozen"
    full_turn = 2 * math.pi
    # Row-major (n_time, n_params) draw: each row is one run's start point.
    starts = np.random.uniform(0, full_turn, size=(n_time, n_params))
    scaled_noise = noise_gamma * noise_rate
    _run_and_save(out_dir, ansatz, starts, s2mask_source, scaled_noise,
                  lr, iteration, n_check, "GD-Uniform")
